{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To save time, start this download first, before reading through the examples.\n",
    "import torch, torchvision, os\n",
    "if not os.path.isfile('datasets/miniplaces/train/yard/00001000.jpg'):\n",
    "    torchvision.datasets.utils.download_and_extract_archive(\n",
    "        'http://dissect.csail.mit.edu/datasets/miniplaces.zip',\n",
    "        'datasets', md5='bfabeb497c7eca01c74cd8441a9ac108')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"dataloader.png\" style=\"max-width:100%\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Datasets and Dataloaders in pytorch\n",
    "===================================\n",
    "\n",
    "Data sets can be thought of as big arrays of data.  If the data set is small enough (e.g., MNIST, which has 60,000 28x28 grayscale images), a dataset can be literally represented as an array - or more precisely, as a single pytorch tensor.  Stored as one 32-bit float per pixel, MNIST takes about 200 megabytes of RAM, which fits comfortably into a modern computer.\n",
    "\n",
    "But larger-scale datasets like ImageNet or Places365 have more than a million higher-resolution full-color images.  In these cases, an ordinary python array or pytorch tensor would require more than a terabyte of RAM, which is impractical on most computers.\n",
    "\n",
    "Instead, we need to load the data from disk (or SSD).  Unfortunately, reading from disk is much slower than reading from RAM, so we need to do the loading cleverly if we want to feed the data quickly.\n",
    "\n",
    "To solve the problem, pytorch provides two classes:\n",
    " * `torch.utils.data.Dataset` - This very simple base class represents an array where the actual data may be slow to fetch, typically because the data is in disk files that require some loading, decoding, or other preprocessing. Pytorch provides a variety of different `Dataset` subclasses.  For example, there is a handy one called `ImageFolder` that treats a directory tree of image files as an array of classified images.\n",
    " * `torch.utils.data.DataLoader` - This fancy class wraps a `Dataset` as a stream of data batches.  Behind the scenes it uses a few techniques to feed the data faster.  You do not need to subclass `DataLoader` - its purpose is to make a `Dataset` speedy."
   ]
  },
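  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the `Dataset` idea concrete, here is a minimal sketch of a custom subclass wrapped in a `DataLoader`.  The class name `SquaresDataset` and its contents are hypothetical and chosen only for illustration; the sketch assumes the usual map-style convention of implementing `__len__` and `__getitem__`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "class SquaresDataset(Dataset):\n",
    "    # Hypothetical toy dataset: item i is the pair (i, i squared).\n",
    "    def __init__(self, size):\n",
    "        self.size = size\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.size\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        # A real dataset would read and decode a file here.\n",
    "        x = torch.tensor(float(index))\n",
    "        return x, x * x\n",
    "\n",
    "squares = SquaresDataset(10)\n",
    "\n",
    "# num_workers=0 loads in the main process, which keeps the sketch simple.\n",
    "squares_loader = DataLoader(squares, batch_size=4, num_workers=0)\n",
    "batch_shapes = [tuple(x.shape) for x, y in squares_loader]\n",
    "print(len(squares), squares[3], batch_shapes)\n",
    "```\n",
    "\n",
    "The dataset is indexed one item at a time, while the loader groups items into batches; the last batch is smaller because 10 is not a multiple of 4."
   ]
  },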
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Looking at an image data set using ImageFolder\n",
    "\n",
    "The most common `Dataset` used in computer vision is `ImageFolder`, which loads a set of image files from a directory tree.  It treats every subdirectory of images as a classification category.  To demonstrate it, we will use it to load images from the miniplaces dataset downloaded above.\n",
    "\n",
    "**Directory layout.** Notice that `datasets/miniplaces/val` contains a set of 100 directories with names like `golf_course`.  Each of these directories contains 100 images, each stored as a jpeg file: 10000 images in total."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ls datasets/miniplaces/val/golf_course"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Constructing an ImageFolder.**  Making an ImageFolder at the root directory of the data set creates an object that behaves like an array: it has a length, and each entry is a tuple containing an image and a number.  The image is stored as a `PIL` image object, a standard python representation for images, and the number denotes the classification class - with one class for each folder, numbered in alphabetical order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_set = torchvision.datasets.ImageFolder('datasets/miniplaces/val')\n",
    "print('Length is', len(val_set))\n",
    "item = val_set[5100]\n",
    "print('5100th item is a pair', item)\n",
    "\n",
    "# Display the PIL image and the class name directly.\n",
    "display(item[0])\n",
    "print('Class name is', val_set.classes[item[1]])"
   ]
  },
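  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The class numbering can be illustrated without any images.  The sketch below is an approximation of the idea, not the library's own code: the helper `find_classes_sketch` and the three folder names are hypothetical.  On a real `ImageFolder`, the corresponding information is available in the `classes` and `class_to_idx` attributes.\n",
    "\n",
    "```python\n",
    "import os, tempfile\n",
    "\n",
    "def find_classes_sketch(root):\n",
    "    # Approximation: one class per subdirectory, numbered in sorted order.\n",
    "    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())\n",
    "    return classes, {name: i for i, name in enumerate(classes)}\n",
    "\n",
    "# Build a throwaway directory tree with three made-up category folders.\n",
    "with tempfile.TemporaryDirectory() as demo_root:\n",
    "    for name in ['yard', 'bridge', 'golf_course']:\n",
    "        os.makedirs(os.path.join(demo_root, name))\n",
    "    demo_classes, demo_class_to_idx = find_classes_sketch(demo_root)\n",
    "\n",
    "print(demo_classes, demo_class_to_idx)\n",
    "```\n",
    "\n",
    "The folders are numbered by sorted name rather than by creation order, so `bridge` becomes class 0 and `yard` becomes class 2."
   ]
  },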
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Transforming the PIL image into a pytorch tensor.**  A PIL image is not convenient for training: we would prefer our data set to return pytorch tensors.  So we can tell `ImageFolder` to do this by specifying the `transform` function on construction.  Pytorch comes with a standard transform function `torchvision.transforms.ToTensor()` which converts an image to a pytorch tensor.\n",
    "\n",
    "Now when indexing into the data set, we will get a pytorch tensor instead of a PIL image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_set =  torchvision.datasets.ImageFolder(\n",
    "    'datasets/miniplaces/val',\n",
    "    transform=torchvision.transforms.ToTensor())\n",
    "print(val_set[1000])\n",
    "\n",
    "# There is an inverse transform that can be used to convert it back to a PIL image,\n",
    "# handy if we want to see it.\n",
    "as_image = torchvision.transforms.ToPILImage()\n",
    "display(as_image(val_set[1000][0]))"
   ]
  },
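  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It can help to see what `ToTensor()` does to a tiny image.  The sketch below uses a made-up 6x4 single-color image rather than a miniplaces file; it shows that the channel dimension comes first in the tensor and that 8-bit pixel values are rescaled to floats between 0 and 1.\n",
    "\n",
    "```python\n",
    "import torch, torchvision\n",
    "from PIL import Image\n",
    "\n",
    "# A made-up image, 6 pixels wide and 4 pixels high, filled with one RGB color.\n",
    "demo_image = Image.new('RGB', (6, 4), color=(255, 0, 128))\n",
    "demo_tensor = torchvision.transforms.ToTensor()(demo_image)\n",
    "\n",
    "# The layout is (channels, height, width).\n",
    "print(demo_tensor.shape, demo_tensor.dtype)\n",
    "# The three channel values of the top-left pixel, rescaled from 0..255 to 0..1.\n",
    "print(demo_tensor[:, 0, 0])\n",
    "```\n",
    "\n",
    "Note that PIL reports sizes as (width, height), while the tensor puts height before width."
   ]
  },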
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fast Dataset Access using DataLoader\n",
    "\n",
    "When we use a dataset for training, we will usually run through the whole dataset in batches.  We could do this ourselves, as in lines 7-9 below, by just fetching the images one at a time and grouping them.\n",
    "\n",
    "But a faster way to iterate through the dataset is to wrap our `val_set` object in a `torch.utils.data.DataLoader` object, as shown on lines 15-19 below.  The `val_loader` we get can pull data out of the Dataset much faster than doing it in the simple way; the `DataLoader` class does this by using several worker processes to load and prefetch the data in parallel.\n",
    "\n",
    "The speedup will depend on the system and the number of worker processes you use (specified using `num_workers`).  In practice using `DataLoader` will typically be 5-20 times faster than direct `Dataset` access."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "print('Going over the data set as an array.')\n",
    "start = time.time()\n",
    "summed_image_dataset = 0\n",
    "batch_size = 100\n",
    "for i in range(0, len(val_set), batch_size):\n",
    "    image_batch = torch.stack([val_set[j][0] for j in range(i, min(i + batch_size, len(val_set)))])\n",
    "    summed_image_dataset += image_batch.sum(0)\n",
    "end = time.time()\n",
    "print(f'Took {end - start} seconds')\n",
    "\n",
    "print('Going over the same dataset using a dataloader.')\n",
    "start = time.time()\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    val_set, batch_size=batch_size, num_workers=10)\n",
    "summed_image_loader = 0\n",
    "for image_batch, label_batch in val_loader:\n",
    "    summed_image_loader += image_batch.sum(0)\n",
    "end = time.time()\n",
    "print(f'Took {end - start} seconds')\n",
    "\n",
    "print('Numerical difference is exactly', (summed_image_loader - summed_image_dataset).norm().item())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise\n",
    "\n",
    "1. Try adjusting `num_workers` down to 1 and up to 100.  How does this affect the speed?\n",
    "2. Try changing `batch_size` down to 1 or up to 1000.\n",
    "\n",
    "**Note**: the speed differences you see will depend on the specifics of your system setup.\n",
    "If you are running on Google Colab, you may not see much of a speedup from DataLoader.\n",
    "This is because Colab provides a very low-latency virtual disk (so direct Dataset access\n",
    "is faster than on a regular computer), and a virtual CPU with limited concurrency\n",
    "(so the DataLoader worker processes give less benefit than normal)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: copy the code above and alter:\n",
    "# 1. num_workers and note the changes in speed\n",
    "# 2. batch_size and note the changes in speed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Other common dataloader tricks.**  `DataLoader` can do a few more useful things.\n",
    "\n",
    " * Although a DataLoader does not put batches on the GPU directly (because of multiprocessing limitations), it *can* put the batch in pinned memory, which is faster to copy to the GPU later after you get it out of the DataLoader.  Make the DataLoader with `pin_memory=True` for this.\n",
    " * During training you usually do not want the batches in alphabetical order.  The DataLoader can shuffle the data so that each batch is a random sample instead of a sequential run.  Pass `shuffle=True` for this."
   ]
  },
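  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two options above can be tried on a small in-memory dataset.  The following is a sketch under simple assumptions: the data is made up (12 random vectors whose label equals their index), and pinned memory is only requested when a GPU is available.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "# Made-up data: 12 items whose label equals their index.\n",
    "toy_set = TensorDataset(torch.randn(12, 3), torch.arange(12))\n",
    "toy_loader = DataLoader(\n",
    "    toy_set, batch_size=4, shuffle=True,\n",
    "    # Pinned memory only helps when a GPU is present.\n",
    "    pin_memory=torch.cuda.is_available())\n",
    "\n",
    "for epoch in range(2):\n",
    "    # Collect the labels in the order in which the loader delivers them.\n",
    "    seen = torch.cat([labels for _, labels in toy_loader])\n",
    "    print('epoch', epoch, 'order', seen.tolist())\n",
    "```\n",
    "\n",
    "Each epoch visits every item exactly once, but in a freshly shuffled order.  When batches come from pinned memory, copying them with `images.cuda(non_blocking=True)` allows the transfer to overlap with other work."
   ]
  },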
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using a DataLoader for Training\n",
    "\n",
    "We can put everything together by using the data from a data loader to train a classifier.\n",
    "\n",
    "The following is a simplistic example of training an image classifier.  It uses the Adam optimizer and the ResNet-18 neural network architecture, and trains for a couple minutes, just passing once over the training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Create a Dataset of miniplaces training images.\n",
    "train_set = torchvision.datasets.ImageFolder(\n",
    "    'datasets/miniplaces/train',\n",
    "    torchvision.transforms.ToTensor())\n",
    "\n",
    "# Wrap the Dataset in a high-speed DataLoader with batch_size 100.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_set, batch_size=100, num_workers=10,\n",
    "    shuffle=True,\n",
    "    pin_memory=True)\n",
    "\n",
    "# Create an untrained neural network using the ResNet 18 architecture.\n",
    "model = torchvision.models.resnet18(num_classes=100).cuda()\n",
    "\n",
    "# Set up the model for training using the Adam optimizer.\n",
    "model.train()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "# To train, optimize an objective on batches of training data.\n",
    "# Here we look at every training image once.\n",
    "for batch in tqdm(train_loader):\n",
    "    # Move both the images and the labels to the GPU.\n",
    "    images, labels = [d.cuda() for d in batch]\n",
    "    optimizer.zero_grad()\n",
    "    scores = model(images)\n",
    "    loss = torch.nn.functional.cross_entropy(scores, labels)\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Checking Accuracy with a Held-Out Dataset\n",
    "\n",
    "To check whether the network has learned anything useful, we can test whether the model makes good predictions on unseen images.  The easy way to do this is to create a second `ImageFolder` dataset (and `DataLoader`) with a second set of images that was **not** used for training.\n",
    "\n",
    "While the achieved accuracy after a couple minutes of training is not perfect, it is already much better than random."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a validation dataset and data loader.\n",
    "val_set = torchvision.datasets.ImageFolder(\n",
    "    'datasets/miniplaces/val',\n",
    "    torchvision.transforms.ToTensor())\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    val_set, batch_size=100, num_workers=10,\n",
    "    pin_memory=True)\n",
    "\n",
    "# This function runs over the validation images and counts accurate predictions.\n",
    "def accuracy():\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    for batch_index, batch in enumerate(val_loader):\n",
    "        images, labels = [d.cuda() for d in batch]\n",
    "        with torch.no_grad():\n",
    "            scores = model(images)\n",
    "        # scores.max(1)[1] is the index of the highest-scoring class per image.\n",
    "        correct += (scores.max(1)[1] == labels).float().sum()\n",
    "    return correct.item() / len(val_set)\n",
    "\n",
    "print(f'Accuracy on unseen images {accuracy() * 100}% (random guesses would be 1%)')"
   ]
  },
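  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The counting step inside `accuracy()` can be checked by hand on a tiny example.  The scores and labels below are made up for illustration (4 images, 3 classes); the sketch only shows how `max(1)[1]` turns a batch of scores into predicted class numbers.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Made-up scores for 4 images and 3 classes (one row per image).\n",
    "toy_scores = torch.tensor([[2.0, 0.1, 0.3],\n",
    "                           [0.2, 1.5, 0.1],\n",
    "                           [0.9, 0.8, 0.7],\n",
    "                           [0.1, 0.2, 3.0]])\n",
    "toy_labels = torch.tensor([0, 1, 2, 2])\n",
    "\n",
    "# max(1) returns (values, indices); the indices are the predicted classes.\n",
    "toy_preds = toy_scores.max(1)[1]\n",
    "toy_accuracy = (toy_preds == toy_labels).float().mean().item()\n",
    "print(toy_preds.tolist(), toy_accuracy)\n",
    "```\n",
    "\n",
    "Three of the four predictions match their labels, so the accuracy on this toy batch is 0.75."
   ]
  },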
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise\n",
    "\n",
    "1. For every 10th batch, display the first image in the batch.\n",
    "2. Also print the predicted class name and the true class name for that image.\n",
    "\n",
    "Hints:\n",
    "* Use the `as_image` function defined in a previous cell.\n",
    "* Use `images[0].cpu()` to move the image to the CPU before displaying it.\n",
    "* The prediction of the network for the 0th item of the batch is `scores.max(1)[1][0]`\n",
    "* Use `val_set.classes[pred]` to convert the numerical prediction to a readable label."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Improving Training using Data Augmentation\n",
    "\n",
    "One of the main ways to stretch a data set to make it more effective for training is to randomly adjust the images.  For example, if we randomly adjust the crop, color, or orientation of the image while loading, using the same image file multiple times will produce different training examples for the network.  This is an easy way to increase the amount of training diversity in the data set without requiring more actual images.\n",
    "\n",
    "To do data augmentation in a pytorch `Dataset`, you can specify more operations in the `transform=` argument besides `ToTensor()`.\n",
    "\n",
    "In particular, there is a `Compose` transform that makes it easy to chain a series of data transformations; and `torchvision.transforms` includes a number of useful image transforms such as random resized crops and image flips.\n",
    "\n",
    "Here is an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a Dataset of miniplaces training images.\n",
    "train_set = torchvision.datasets.ImageFolder(\n",
    "    'datasets/miniplaces/train',\n",
    "    torchvision.transforms.Compose([\n",
    "        torchvision.transforms.RandomCrop(112),\n",
    "        torchvision.transforms.RandomHorizontalFlip(),\n",
    "        torchvision.transforms.ToTensor(),\n",
    "    ]))\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_set, batch_size=100, num_workers=10,\n",
    "    shuffle=True,\n",
    "    pin_memory=True)\n",
    "\n",
    "# Now let's train for one more epoch, and test the accuracy.\n",
    "model.train()\n",
    "for batch in tqdm(train_loader):\n",
    "    images, labels = [d.cuda() for d in batch]\n",
    "    optimizer.zero_grad()\n",
    "    scores = model(images)\n",
    "    loss = torch.nn.functional.cross_entropy(scores, labels)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "print(f'Accuracy on unseen images {accuracy() * 100}% (random guesses would be 1%)')"
   ]
  },
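  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the effect of the random transforms in isolation, the same pipeline can be applied several times to a single image.  This sketch assumes a made-up 128x128 image of random noise in place of a miniplaces file; each call draws a new crop position and a new flip decision.\n",
    "\n",
    "```python\n",
    "import torch, torchvision\n",
    "\n",
    "augment = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.RandomCrop(112),\n",
    "    torchvision.transforms.RandomHorizontalFlip(),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "# A made-up 128x128 RGB image of random noise, converted to a PIL image.\n",
    "noise_image = torchvision.transforms.ToPILImage()(torch.rand(3, 128, 128))\n",
    "\n",
    "# Apply the same augmentation pipeline four times to the same image.\n",
    "views = [augment(noise_image) for _ in range(4)]\n",
    "print([tuple(v.shape) for v in views])\n",
    "```\n",
    "\n",
    "Every view has the cropped size of 112x112, but the pixel contents generally differ from call to call, which is what gives the network new training examples from the same file."
   ]
  },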
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise\n",
    "\n",
    "1. Print out the same images as before, with updated predictions for the newly tuned network parameters.\n",
    "2. Repeat training for a few more epochs.  How does the accuracy evolve?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Epilog\n",
    "\n",
    "Almost all the pytorch code you will find will be variations and extensions of the patterns we have covered.  You're ready to explore.\n",
    "\n",
    "Have fun!\n",
    "\n",
    "### [Back to the introduction &rightarrow;](1-Pytorch-Introduction.ipynb)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
